{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0.0.1\n",
      "Python version: 3.8.1 (default, Jan  8 2020, 22:29:32)  [GCC 7.3.0]\n",
      "Numpy version: 1.18.1\n",
      "Pytorch version: 1.4.0\n",
      "Ignite version: 0.3.0\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from ignite.engine import create_supervised_trainer\n",
    "from ignite.engine.engine import Events\n",
    "from ignite.handlers import ModelCheckpoint\n",
    "\n",
    "import monai\n",
    "\n",
    "\n",
    "# Print the installed MONAI, Python, NumPy, PyTorch and Ignite versions.\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "State:\n",
       "\titeration: 4\n",
       "\tepoch: 2\n",
       "\tepoch_length: 2\n",
       "\tmax_epochs: 2\n",
       "\toutput: 26133.853515625\n",
       "\tbatch: <class 'tuple'>\n",
       "\tmetrics: <class 'dict'>\n",
       "\tdataloader: <class 'generator'>\n",
       "\tseed: 12"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr = 0.001  # learning rate passed to each Adam optimizer created in this cell\n",
    "\n",
    "# Keyword arguments for the single network instance used by this cell.\n",
    "unet_config = {\n",
    "    \"dimensions\": 2,\n",
    "    \"in_channels\": 1,\n",
    "    \"out_channels\": 1,\n",
    "    \"channels\": (16, 32, 64, 128, 256),\n",
    "    \"strides\": (2, 2, 2, 2),\n",
    "    \"num_res_units\": 2,\n",
    "}\n",
    "net = monai.networks.nets.UNet(**unet_config)\n",
    "\n",
    "\n",
    "def fake_loss(y_pred, y):\n",
    "    \"\"\"Return the sum over all elements of ``y_pred[0] + y``.\n",
    "\n",
    "    Placeholder loss used only to drive the trainers in this notebook; it is\n",
    "    not a meaningful training objective.\n",
    "\n",
    "    NOTE(review): ``y_pred[0]`` presumably selects the prediction tensor from\n",
    "    a tuple returned by the network - confirm against the UNet's forward().\n",
    "    \"\"\"\n",
    "    combined = y_pred[0] + y\n",
    "    return combined.sum()\n",
    "\n",
    "\n",
    "def fake_data_stream():\n",
    "    \"\"\"Yield an endless sequence of 2-tuples of random tensors.\n",
    "\n",
    "    Each tuple holds two independent ``torch.rand`` draws, both of shape\n",
    "    ``(10, 1, 64, 64)``. The generator never terminates, so the consumer must\n",
    "    bound how many items it pulls.\n",
    "    \"\"\"\n",
    "    batch_shape = (10, 1, 64, 64)\n",
    "    while True:\n",
    "        first = torch.rand(batch_shape)\n",
    "        second = torch.rand(batch_shape)\n",
    "        yield first, second\n",
    "\n",
    "\n",
    "# Device specs to exercise, in order: one explicit GPU, all GPUs (None) and\n",
    "# CPU (empty list). The labels are only used in the skip message.\n",
    "device_specs = [\n",
    "    (\"1 GPU\", [torch.device(\"cuda:0\")]),\n",
    "    (\"all GPUs\", None),\n",
    "    (\"CPU\", []),\n",
    "]\n",
    "\n",
    "state = None\n",
    "for label, devices in device_specs:\n",
    "    uses_cuda = devices is None or any(d.type == \"cuda\" for d in devices)\n",
    "    if uses_cuda and not torch.cuda.is_available():\n",
    "        # Keep the notebook runnable from a fresh kernel on CPU-only hosts.\n",
    "        print(f\"Skipping '{label}' run: CUDA is not available\")\n",
    "        continue\n",
    "\n",
    "    # A new optimizer per run so no Adam state carries over between trainers.\n",
    "    opt = torch.optim.Adam(net.parameters(), lr)\n",
    "    trainer = monai.engine.create_multigpu_supervised_trainer(net, opt, fake_loss, devices)\n",
    "    # Two epochs of two batches each; the data stream itself is endless.\n",
    "    state = trainer.run(fake_data_stream(), max_epochs=2, epoch_length=2)\n",
    "\n",
    "# Show the final engine state of the last run (the CPU run is never skipped).\n",
    "state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
